"""	Maintained by Roger
	"""


import numpy as np
import scipy.sparse as sparse
from algorithms.linalg import np_conserved as npc
from models.model import model
from mps.mps import iMPS
from mps.impo import MPO
from tools.string import joinstr
from tools.math import toiterable
from tools.math import tonparray
from models.model import set_var
from models.model import any_nonzero


class tri_ising_model(model):
	"""	Spin-1/2 Tri Ising model
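		Three spin-1/2 degrees of freedom (S, T, U) per site; local dimension d = 8.
		Hamiltonian:
			H = sum of onsite terms + sum of nearest-neighbor ("hopping") terms.
			Note: the extra hoppings (pars['extra_hoppings']) are what define the kinetic-energy (KE) part of the Hamiltonian.

			S_, T_, and U_ operators are spin matrices (not Pauli) with eigenvalues +-0.5;
			the p*** operators are unnormalized Pauli strings (e.g. pXII == 2 * Sx).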
		"""
	# negative J/K is ferromagnetic
	# positive J/K is antiferromagnetic
	
	def __init__(self, pars):
		super(tri_ising_model, self).__init__()
		self.verbose = set_var(pars, 'verbose', 0, force_verbose=0)
		verbose = self.verbose
		if verbose > 0:
			print "tri-Ising model"
			print "\tverbose:", verbose
		self.L = L = set_var(pars, 'L', 1)
		self.bc = set_var(pars, 'bc', 'periodic')
		#self.nuS = set_var(pars, 'nuS', (0,1))		# filling fraction, average <S_z>
		#self.nuT = set_var(pars, 'nuT', (0,1))		# filling fraction, average <T_z>
		#self.nuU = set_var(pars, 'nuU', (0,1))    # filling fraction, average <U_z>
		self.dtype = set_var(pars, 'dtype', float)

		self.site_ops = set_var(pars, 'extra_site_ops', [])
		assert type(self.site_ops) == list


		self.extra_onsite_terms = set_var(pars, 'extra_onsite', [ [] for i in range(self.L) ])
		if len(self.extra_onsite_terms)!= self.L: raise ValueError
		for onsite_term in self.extra_onsite_terms:
			for h in onsite_term:
				if not isinstance(h, tuple) or len(h) != 2: raise ValueError, "Bad onsite term: " + str(h)
		##	extra_hoppings should be a list of list of 3-tuples: e.g. ('SxTx', 'Ty', -1.)
		self.extra_hoppings = set_var(pars, 'extra_hoppings', [ [] for i in range(self.L) ])
		if len(self.extra_hoppings) != self.L: raise ValueError
		for hoppings in self.extra_hoppings:
			for h in hoppings:
				if not isinstance(h, tuple) or len(h) != 3: raise ValueError, "Bad hopping term: " + str(h)
		# TODO, determine if complex is needed

		self.conserve_Sz = set_var(pars, 'conserve_Sz', True, force_verbose=0)

		## conserve_Sz overrides conserve_SzZ2
		if self.conserve_Sz:
			self.num_q = 3
		else:
			self.num_q = 0
		if verbose > 0:
			if self.conserve_Sz:
				print "\tconserving Sz, (num_q = %s)" % self.num_q
			else:
				print "\tno conserved quantity, (num_q = %s)" % self.num_q
    
		self.fracture_mpo = set_var(pars, 'fracture_mpo', False)		## don't use this

		self.construct_mpo()



	def construct_mpo(self):
		""" Clears current H_mpo, and constructs L copies of H_mpo from local temp W
			"""
		verbose = self.verbose
		L = self.L
		self.H = []
		self.w_bond = []
		self.v_bond = []
		d = self.d = 8
		num_q = self.num_q

	##	Construct spin operators
	##	Hilbert space: (Sup Hup, Sup Tdn, Sdn Tup, Sdn Tdn)
		self.upupup = 0
		self.upupdn = 1
		self.updnup = 2
		self.updndn = 3
		self.dnupup = 4
		self.dnupdn = 5
		self.dndnup = 6
		self.dndndn = 7
		self.states = ['upupup', 'upupdn', 'updnup', 'updndn', 'dnupup', 'dnupdn', 'dndnup', 'dndndn']
                
	##	First make the Pauli matrices
		I2 = np.eye(2, dtype=int)
		Px = np.array([[0,1],[1,0]])
		Py = np.array([[0,-1j],[1j,0]])
		Pz = np.array([[1,0],[0,-1]])
		Sp = np.array([[0,1.],[0,0]])
		Sm = np.array([[0,0],[1.,0]])
		PauliMatrices = [('x',Px), ('y',Py), ('z',Pz)]
		
	##	Make the 8x8 operators  --  these are all duplicates
		Id = self.Id = np.eye(d, dtype=int)
		Sx = self.Sx = reduce(np.kron, [Px, I2, I2]) * 0.5 		# check
		Sy = self.Sy = reduce(np.kron, [Py, I2, I2]) * 0.5 
		Sz = self.Sz = reduce(np.kron, [Pz, I2, I2]) * 0.5 
		Tx = self.Tx = reduce(np.kron, [I2, Px, I2]) * 0.5
		Ty = self.Ty = reduce(np.kron, [I2, Py, I2]) * 0.5
		Tz = self.Tz = reduce(np.kron, [I2, Pz, I2]) * 0.5
		Ux = self.Ux = reduce(np.kron, [I2, I2, Px]) * 0.5
		Uy = self.Uy = reduce(np.kron, [I2, I2, Py]) * 0.5
		Uz = self.Uz = reduce(np.kron, [I2, I2, Pz]) * 0.5
		# make Ux, Uy, Uz, add it to the list on the nxt line
		self.site_ops = self.site_ops + ['Id', 'Sz', 'Sx', 'Sy', 'Tx', 'Ty', 'Tz', 'Ux', 'Uy', 'Uz']

	#I don't think we need this
	#	# make S_a T_b type operators (9 of them)
	#	for a,Pa in PauliMatrices:
	#		for b,Pb in PauliMatrices:
	#			attrstr = "S" + a + "T" + b
	#			setattr(self, attrstr, np.kron(Pa, Pb) * 0.25)
	#			self.site_ops.append(attrstr)
		
	##	Pauli matrices (same as above by with Pauli normalization)
		#self.pII = Id
		#self.pXI = 2 * Sx; self.pYI = 2 * Sy; self.pZI = 2 * Sz
		#self.pIX = 2 * Tx; self.pIY = 2 * Ty; self.pIZ = 2 * Tz
		#self.pXX = np.kron(Px, Px); self.pXY = np.kron(Px, Py); self.pXZ = np.kron(Px, Pz)
		#self.pYX = np.kron(Py, Px); self.pYY = np.kron(Py, Py); self.pYZ = np.kron(Py, Pz)
		#self.pZX = np.kron(Pz, Px); self.pZY = np.kron(Pz, Py); self.pZZ = np.kron(Pz, Pz)
		#self.site_ops = self.site_ops + ['pII', 'pXI', 'pYI', 'pZI', 'pIX', 'pIY', 'pIZ']
		#self.site_ops = self.site_ops + ['pXX', 'pXY', 'pXZ', 'pYX', 'pYY', 'pYZ', 'pZX', 'pZY', 'pZZ']
		Pdict = {'I':I2, 'X':Px, 'Y':Py, 'Z':Pz, 'P':Sp, 'M':Sm}
		for n1,m1 in Pdict.iteritems():
			for n2,m2 in Pdict.iteritems():
				for n3,m3 in Pdict.iteritems():
					Opname = 'p' + n1 + n2 + n3
					Op = reduce(np.kron, [m1, m2, m3])
					setattr(self,Opname, Op)
					self.site_ops.append(Opname)


#	##	Make bond operators
#	##		indices iL' iR' iL iR
#		self.SxSx = np.kron(Sx, Sx).reshape((d,d,d,d))
#		self.SySy = np.kron(Sy, Sy).real.reshape((d,d,d,d))
#		self.SzSz = np.kron(Sz, Sz).reshape((d,d,d,d))
#		self.TxTx = np.kron(Tx, Tx).reshape((d,d,d,d))
#		self.TyTy = np.kron(Ty, Ty).real.reshape((d,d,d,d))
#		self.TzTz = np.kron(Tz, Tz).reshape((d,d,d,d))
#		self.bond_ops = ['SxSx', 'SySy', 'SzSz', 'TxTx', 'TyTy', 'TzTz']
		
	##	Set charges for physical index
		Qp_flat = np.zeros( (d, self.num_q), dtype = int )
		if self.conserve_Sz:
			Qp_flat[:] = np.array([[0,0,0],[0,0,1],[0,1,0],[0,1,1],[1,0,0],[1,0,1],[1,1,0],[1,1,1]])
			self.mod_q = np.array([1,1,1])
		else:
			self.mod_q = np.empty((0,), int)
		assert len(self.mod_q) == self.num_q
		#print "\tQp_flat:", Qp_flat
		self.Qp_flat = [Qp_flat] * L
		
		
		
		
	##	Set H:
	##		one cell Hamiltonian: indices i' i
	##		two cell Hamiltonian, indices iL' iR' iL iR
		self.H_twocell = []		# labelled by bonds
		self.H_onecell = []		# labelled by sites
		for i in range(L):
			H_twocell = np.zeros((d*d, d*d), self.dtype)
			for OL,OR,g in self.extra_hoppings[i]:
				H_twocell += np.real_if_close(g * np.kron(getattr(self, OL), getattr(self, OR)), tol = 50)
				if verbose >= 3 : print "\t[bond {}] Adding hoppings: {} ({} * {})".format(i, g, OL, OR)
				if verbose >= 5 : print joinstr(['\t\t', getattr(self, OL), getattr(self, OR)])
			H_onecell = np.zeros((d,d), self.dtype)
			for Op,g in self.extra_onsite_terms[i]:
				H_onecell += np.real_if_close(g * getattr(self,Op), tol = 50)
				if verbose >= 3: print "\t[site {}] Adding onsite: {} ({})".format(i,g,Op)

			H_onecell = np.real_if_close(H_onecell, tol = 50)
			H_twocell = np.real_if_close(H_twocell, tol = 50)
			self.H_onecell.append(H_onecell)
			self.H_twocell.append(H_twocell.reshape((d,d,d,d)))
		
	##	Constructs H
		for i in range(L):
			i2 = (i+1) % L
			H = self.H_twocell[i] + ( np.kron(Id, self.H_onecell[i2]) + np.kron(self.H_onecell[i], Id) ).reshape((d,d,d,d)) / 2.
			self.H.append(H)

	##	Tack on additional 1-cell term at boundaries if finite
		
		if self.bc == 'finite':
			self.H[0] += np.kron(self.H_onecell[0], Id).reshape((d,d,d,d)) / 2.
			self.H[-2] += np.kron(Id, self.H_onecell[-1]).reshape((d,d,d,d)) / 2.
			self.H[-1] = np.zeros((d,d,d,d))
			
	##	Creates the MPO from H
		self.twocellH_to_H_mpo()
		
		
	##	Get rid of numbers close to zero
		zero_cutoff = 2e-16
		for s in range(L):
			#print "site", s
   			self.H_mpo[s] = np.where(np.abs(self.H_mpo[s]) < zero_cutoff, 0, self.H_mpo[s])
			#print self.Qmpo_flat[s]
			#print joinstr(map(str, H_mpo[s])), "H_mpo: right(chi), left(chi), [ bottom(d), top(d) ]"
			#print "vL,vR:", self.vL[(s-1)%L], self.vR[s]
		
		
	##	Conserve the sh## out of the Model
		self.init_model_for_conservation()
	##	Calculate W and V?
		self.calc_wv()
		
	##	You know... just in case
		for i in range(L):
			self.H_mpo[i].check_sanity()
	##	Whee!! The mpo is done!!
		
	
